{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1a558ee1",
   "metadata": {},
   "source": [
    "<a href=\"https://colab.research.google.com/github/jerryjliu/llama_index/blob/main/docs/examples/index_structs/struct_indices/SQLIndexDemo.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "e45f9b60-cd6b-4c15-958f-1feca5438128",
   "metadata": {},
   "source": [
    "# Text-to-SQL Guide (Query Engine + Retriever)\n",
    "\n",
    "This is a basic guide to LlamaIndex's Text-to-SQL capabilities. \n",
    "1. We first show how to perform text-to-SQL over a toy dataset: this will do \"retrieval\" (sql query over db) and \"synthesis\".\n",
    "2. We then show how to buid a TableIndex over the schema to dynamically retrieve relevant tables during query-time.\n",
    "3. We finally show you how to define a text-to-SQL retriever on its own. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f3f3baa",
   "metadata": {},
   "source": [
    "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2480c7af",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-llms-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02b4e312",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f69be8d-99ae-4d9f-91e4-b90fc62bcf2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e28ca96e-f98f-4a72-9fe3-a372dbd08a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-..\"\n",
    "openai.api_key = os.environ[\"OPENAI_API_KEY\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "119eb42b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import logging\n",
    "# import sys\n",
    "\n",
    "# logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "107396a9-4aa7-49b3-9f0f-a755726c19ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "461438c8-302d-45c5-8e69-16ad604686d1",
   "metadata": {},
   "source": [
    "### Create Database Schema\n",
    "\n",
    "We use `sqlalchemy`, a popular SQL database toolkit, to create an empty `city_stats` Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a370b266-66f5-4624-bbf9-2ad57f0511f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sqlalchemy import (\n",
    "    create_engine,\n",
    "    MetaData,\n",
    "    Table,\n",
    "    Column,\n",
    "    String,\n",
    "    Integer,\n",
    "    select,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea24f794-f10b-42e6-922d-9258b7167405",
   "metadata": {},
   "outputs": [],
   "source": [
    "engine = create_engine(\"sqlite:///:memory:\")\n",
    "metadata_obj = MetaData()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4154b29-7e23-4c26-a507-370a66186ae7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# create city SQL table\n",
    "table_name = \"city_stats\"\n",
    "city_stats_table = Table(\n",
    "    table_name,\n",
    "    metadata_obj,\n",
    "    Column(\"city_name\", String(16), primary_key=True),\n",
    "    Column(\"population\", Integer),\n",
    "    Column(\"country\", String(16), nullable=False),\n",
    ")\n",
    "metadata_obj.create_all(engine)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "1c09089a-6bcd-48db-8120-a84c8da3f82e",
   "metadata": {},
   "source": [
    "### Define SQL Database\n",
    "\n",
    "We first define our `SQLDatabase` abstraction (a light wrapper around SQLAlchemy). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "768d1581-b482-4c73-9963-5ffd68a2aafb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import SQLDatabase\n",
    "from llama_index.llms.openai import OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bffabba0-8e54-4f24-ad14-2c8979c582a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = OpenAI(temperature=0.1, model=\"gpt-3.5-turbo\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9432787b-a8f0-4fc3-8323-e2cd9497df73",
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "bad7ffbe",
   "metadata": {},
   "source": [
    "We add some testing data to our SQL database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95043e10-6cdf-4f66-96bd-ce307ea7df3e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_database = SQLDatabase(engine, include_tables=[\"city_stats\"])\n",
    "from sqlalchemy import insert\n",
    "\n",
    "rows = [\n",
    "    {\"city_name\": \"Toronto\", \"population\": 2930000, \"country\": \"Canada\"},\n",
    "    {\"city_name\": \"Tokyo\", \"population\": 13960000, \"country\": \"Japan\"},\n",
    "    {\n",
    "        \"city_name\": \"Chicago\",\n",
    "        \"population\": 2679000,\n",
    "        \"country\": \"United States\",\n",
    "    },\n",
    "    {\"city_name\": \"Seoul\", \"population\": 9776000, \"country\": \"South Korea\"},\n",
    "]\n",
    "for row in rows:\n",
    "    stmt = insert(city_stats_table).values(**row)\n",
    "    with engine.begin() as connection:\n",
    "        cursor = connection.execute(stmt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b315b8ff-7dd7-4e7d-ac47-8c5a0c3e7ae9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Chicago', 2679000, 'United States'), ('Seoul', 9776000, 'South Korea')]\n"
     ]
    }
   ],
   "source": [
    "# view current table\n",
    "stmt = select(\n",
    "    city_stats_table.c.city_name,\n",
    "    city_stats_table.c.population,\n",
    "    city_stats_table.c.country,\n",
    ").select_from(city_stats_table)\n",
    "\n",
    "with engine.connect() as connection:\n",
    "    results = connection.execute(stmt).fetchall()\n",
    "    print(results)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "051a171f-8c97-40ed-ae17-4e3fa3785487",
   "metadata": {},
   "source": [
    "### Query Index"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "f6a2303f-3bae-4fa2-8750-03f9af747848",
   "metadata": {},
   "source": [
    "We first show how we can execute a raw SQL query, which directly executes over the table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eddd3608-31ff-4591-a02a-90987e312669",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "('Chicago',)\n",
      "('Seoul',)\n",
      "('Tokyo',)\n",
      "('Toronto',)\n"
     ]
    }
   ],
   "source": [
    "from sqlalchemy import text\n",
    "\n",
    "with engine.connect() as con:\n",
    "    rows = con.execute(text(\"SELECT city_name from city_stats\"))\n",
    "    for row in rows:\n",
    "        print(row)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "4e72b931",
   "metadata": {},
   "source": [
    "## Part 1: Text-to-SQL Query Engine\n",
    "Once we have constructed our SQL database, we can use the NLSQLTableQueryEngine to\n",
    "construct natural language queries that are synthesized into SQL queries.\n",
    "\n",
    "Note that we need to specify the tables we want to use with this query engine.\n",
    "If we don't the query engine will pull all the schema context, which could\n",
    "overflow the context window of the LLM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d992fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import NLSQLTableQueryEngine\n",
    "\n",
    "query_engine = NLSQLTableQueryEngine(\n",
    "    sql_database=sql_database, tables=[\"city_stats\"], llm=llm\n",
    ")\n",
    "query_str = \"Which city has the highest population?\"\n",
    "response = query_engine.query(query_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c0dfe9c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<b>The city with the highest population is Tokyo.</b>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "display(Markdown(f\"<b>{response}</b>\"))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "298b4ca2",
   "metadata": {},
   "source": [
    "This query engine should be used in any case where you can specify the tables you want\n",
    "to query over beforehand, or the total size of all the table schema plus the rest of\n",
    "the prompt fits your context window."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "dee4d251",
   "metadata": {},
   "source": [
    "## Part 2: Query-Time Retrieval of Tables for Text-to-SQL\n",
    "If we don't know ahead of time which table we would like to use, and the total size of\n",
    "the table schema overflows your context window size, we should store the table schema \n",
    "in an index so that during query time we can retrieve the right schema.\n",
    "\n",
    "The way we can do this is using the SQLTableNodeMapping object, which takes in a \n",
    "SQLDatabase and produces a Node object for each SQLTableSchema object passed \n",
    "into the ObjectIndex constructor.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71045c0-7a96-4e86-b38c-c378b7759aa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.indices.struct_store.sql_query import (\n",
    "    SQLTableRetrieverQueryEngine,\n",
    ")\n",
    "from llama_index.core.objects import (\n",
    "    SQLTableNodeMapping,\n",
    "    ObjectIndex,\n",
    "    SQLTableSchema,\n",
    ")\n",
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "# set Logging to DEBUG for more detailed outputs\n",
    "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
    "table_schema_objs = [\n",
    "    (SQLTableSchema(table_name=\"city_stats\"))\n",
    "]  # add a SQLTableSchema for each table\n",
    "\n",
    "obj_index = ObjectIndex.from_objects(\n",
    "    table_schema_objs,\n",
    "    table_node_mapping,\n",
    "    VectorStoreIndex,\n",
    ")\n",
    "query_engine = SQLTableRetrieverQueryEngine(\n",
    "    sql_database, obj_index.as_retriever(similarity_top_k=1)\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "b6156caf",
   "metadata": {},
   "source": [
    "Now we can take our SQLTableRetrieverQueryEngine and query it for our response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "802da9ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "<b>The city with the highest population is Tokyo.</b>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "response = query_engine.query(\"Which city has the highest population?\")\n",
    "display(Markdown(f\"<b>{response}</b>\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54a99cb0-578a-40ec-a3eb-1666ac18fbed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('Tokyo',)]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# you can also fetch the raw result from SQLAlchemy!\n",
    "response.metadata[\"result\"]"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "0d19b9cd",
   "metadata": {},
   "source": [
    "You can also add additional context information for each table schema you define."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44a87651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# manually set context text\n",
    "city_stats_text = (\n",
    "    \"This table gives information regarding the population and country of a\"\n",
    "    \" given city.\\nThe user will query with codewords, where 'foo' corresponds\"\n",
    "    \" to population and 'bar'corresponds to city.\"\n",
    ")\n",
    "\n",
    "table_node_mapping = SQLTableNodeMapping(sql_database)\n",
    "table_schema_objs = [\n",
    "    (SQLTableSchema(table_name=\"city_stats\", context_str=city_stats_text))\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2ef57746-7cbf-48cd-8781-9c048ce3d3c7",
   "metadata": {},
   "source": [
    "## Part 3: Text-to-SQL Retriever\n",
    "\n",
    "So far our text-to-SQL capability is packaged in a query engine and consists of both retrieval and synthesis.\n",
    "\n",
    "You can use the SQL retriever on its own. We show you some different parameters you can try, and also show how to plug it into our `RetrieverQueryEngine` to get roughly the same results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99128819-6acf-4717-b703-9d1ca41190a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.retrievers import NLSQLRetriever\n",
    "\n",
    "# default retrieval (return_raw=True)\n",
    "nl_sql_retriever = NLSQLRetriever(\n",
    "    sql_database, tables=[\"city_stats\"], return_raw=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19bc6858-7e3b-4cf1-806f-0d6d63a55d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = nl_sql_retriever.retrieve(\n",
    "    \"Return the top 5 cities (along with their populations) with the highest population.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3576523f-0b5c-4c40-8b2a-d734bdfbde58",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "**Node ID:** 458f723e-f1ac-4423-917a-522a71763390<br>**Similarity:** None<br>**Text:** [('Tokyo', 13960000), ('Seoul', 9776000), ('Toronto', 2930000), ('Chicago', 2679000)]<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from llama_index.core.response.notebook_utils import display_source_node\n",
    "\n",
    "for n in results:\n",
    "    display_source_node(n)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "acfdc307-65e1-4a0e-8098-f30c0db60c93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# default retrieval (return_raw=False)\n",
    "nl_sql_retriever = NLSQLRetriever(\n",
    "    sql_database, tables=[\"city_stats\"], return_raw=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36ccfa42-51e0-49e7-afac-6a0d011af4e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = nl_sql_retriever.retrieve(\n",
    "    \"Return the top 5 cities (along with their populations) with the highest population.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cd14d5e-7a5c-4898-820a-469b80cee5c9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "**Node ID:** 7c0e4c94-c9a6-4917-aa3f-e3b3f4cbcd5c<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Tokyo', 'population': 13960000}<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** 3c1d1caa-cec2-451e-8fd1-adc944e1d050<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Seoul', 'population': 9776000}<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** fb9f9b25-b913-4dde-a0e3-6111f704aea9<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Toronto', 'population': 2930000}<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/markdown": [
       "**Node ID:** c31ba8e7-de5d-4f28-a464-5e0339547c70<br>**Similarity:** None<br>**Text:** <br>**Metadata:** {'city_name': 'Chicago', 'population': 2679000}<br>"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# NOTE: all the content is in the metadata\n",
    "for n in results:\n",
    "    display_source_node(n, show_source_metadata=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "392889f6-babe-4bce-b802-4379a9b3ca49",
   "metadata": {},
   "source": [
    "### Plug into our `RetrieverQueryEngine`\n",
    "\n",
    "We compose our SQL Retriever with our standard `RetrieverQueryEngine` to synthesize a response. The result is roughly similar to our packaged `Text-to-SQL` query engines."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f85e7a24-bbd0-4439-9da4-26a7a9e43b8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.query_engine import RetrieverQueryEngine\n",
    "\n",
    "query_engine = RetrieverQueryEngine.from_args(nl_sql_retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f2dec0-8700-49d3-81b7-69c21518ac87",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = query_engine.query(\n",
    "    \"Return the top 5 cities (along with their populations) with the highest population.\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56eb7e5a-a23b-47ab-a154-98ce4c0d1c1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The top 5 cities with the highest population are:\n",
      "\n",
      "1. Tokyo - 13,960,000\n",
      "2. Seoul - 9,776,000\n",
      "3. Toronto - 2,930,000\n",
      "4. Chicago - 2,679,000\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
